{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import glob\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-2-17bb7203622b>:1: is_gpu_available (from tensorflow.python.framework.test_util) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.config.list_physical_devices('GPU')` instead.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.test.is_gpu_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "keras = tf.keras\n",
    "layers = tf.keras.layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_image_path = glob.glob('./dataset/dc_2000/train/*/*.jpg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_image_label = [int(p.split('\\\\')[1] == 'cat') for p in train_image_path]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_preprosess_image(path,label):\n",
    "    image = tf.io.read_file(path)\n",
    "    image = tf.image.decode_jpeg(image,channels=3)\n",
    "    image = tf.image.resize(image,[256,256])\n",
    "    image = tf.cast(image,tf.float32)\n",
    "    image = image/255\n",
    "    return image,label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_image_ds = tf.data.Dataset.from_tensor_slices((train_image_path,train_image_label))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "AUTOTUNE = tf.data.experimental.AUTOTUNE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_image_ds = train_image_ds.map(load_preprosess_image,num_parallel_calls=AUTOTUNE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<ParallelMapDataset shapes: ((256, 256, 3), ()), types: (tf.float32, tf.int32)>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_image_ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 32\n",
    "train_count = len(train_image_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This cell must run after BATCH_SIZE and train_count are defined (it used to sit above them,\n",
    "# which raised NameError on a fresh kernel).\n",
    "# Shuffle with a buffer as large as the training set, repeat indefinitely, then batch.\n",
    "train_image_ds = train_image_ds.shuffle(train_count).repeat().batch(BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_image_path = glob.glob('./dataset/dc_2000/test/*/*.jpg')\n",
    "test_image_label = [int(p.split('\\\\')[1] == 'cat') for p in test_image_path]\n",
    "test_image_ds = tf.data.Dataset.from_tensor_slices((test_image_path,test_image_label))\n",
    "test_image_ds = test_image_ds.map(load_preprosess_image,num_parallel_calls=AUTOTUNE)\n",
    "test_image_ds = test_image_ds.repeat().batch(BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1000"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_count = len(test_image_path)\n",
    "test_count"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "keras内置经典网络实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "covn_base = keras.applications.VGG16(weights='imagenet',include_top=False)\n",
    "# weights='imagenet'使用在ImageNet上预训练好的网络，若等于null表示直接使用VGG网络\n",
    "# include_top是否包含最后的输出层（预训练好的分类器）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vgg16\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_1 (InputLayer)         [(None, None, None, 3)]   0         \n",
      "_________________________________________________________________\n",
      "block1_conv1 (Conv2D)        (None, None, None, 64)    1792      \n",
      "_________________________________________________________________\n",
      "block1_conv2 (Conv2D)        (None, None, None, 64)    36928     \n",
      "_________________________________________________________________\n",
      "block1_pool (MaxPooling2D)   (None, None, None, 64)    0         \n",
      "_________________________________________________________________\n",
      "block2_conv1 (Conv2D)        (None, None, None, 128)   73856     \n",
      "_________________________________________________________________\n",
      "block2_conv2 (Conv2D)        (None, None, None, 128)   147584    \n",
      "_________________________________________________________________\n",
      "block2_pool (MaxPooling2D)   (None, None, None, 128)   0         \n",
      "_________________________________________________________________\n",
      "block3_conv1 (Conv2D)        (None, None, None, 256)   295168    \n",
      "_________________________________________________________________\n",
      "block3_conv2 (Conv2D)        (None, None, None, 256)   590080    \n",
      "_________________________________________________________________\n",
      "block3_conv3 (Conv2D)        (None, None, None, 256)   590080    \n",
      "_________________________________________________________________\n",
      "block3_pool (MaxPooling2D)   (None, None, None, 256)   0         \n",
      "_________________________________________________________________\n",
      "block4_conv1 (Conv2D)        (None, None, None, 512)   1180160   \n",
      "_________________________________________________________________\n",
      "block4_conv2 (Conv2D)        (None, None, None, 512)   2359808   \n",
      "_________________________________________________________________\n",
      "block4_conv3 (Conv2D)        (None, None, None, 512)   2359808   \n",
      "_________________________________________________________________\n",
      "block4_pool (MaxPooling2D)   (None, None, None, 512)   0         \n",
      "_________________________________________________________________\n",
      "block5_conv1 (Conv2D)        (None, None, None, 512)   2359808   \n",
      "_________________________________________________________________\n",
      "block5_conv2 (Conv2D)        (None, None, None, 512)   2359808   \n",
      "_________________________________________________________________\n",
      "block5_conv3 (Conv2D)        (None, None, None, 512)   2359808   \n",
      "_________________________________________________________________\n",
      "block5_pool (MaxPooling2D)   (None, None, None, 512)   0         \n",
      "=================================================================\n",
      "Total params: 14,714,688\n",
      "Trainable params: 14,714,688\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "covn_base.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.Sequential()\n",
    "model.add(covn_base)\n",
    "model.add(layers.GlobalAveragePooling2D())\n",
    "model.add(layers.Dense(512,activation='relu'))\n",
    "model.add(layers.Dense(1,activation='sigmoid'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "vgg16 (Functional)           (None, None, None, 512)   14714688  \n",
      "_________________________________________________________________\n",
      "global_average_pooling2d (Gl (None, 512)               0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 512)               262656    \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 1)                 513       \n",
      "=================================================================\n",
      "Total params: 14,977,857\n",
      "Trainable params: 14,977,857\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 保持不动conv_base中的权重,将其设置为不可训练\n",
    "covn_base.trainable = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "vgg16 (Functional)           (None, None, None, 512)   14714688  \n",
      "_________________________________________________________________\n",
      "global_average_pooling2d (Gl (None, 512)               0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 512)               262656    \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 1)                 513       \n",
      "=================================================================\n",
      "Total params: 14,977,857\n",
      "Trainable params: 263,169\n",
      "Non-trainable params: 14,714,688\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(optimizer=keras.optimizers.Adam(lr=0.0005),\n",
    "             loss='binary_crossentropy',\n",
    "             metrics=['acc'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/4\n",
      "62/62 [==============================] - 1231s 20s/step - loss: 0.4629 - acc: 0.7888 - val_loss: 0.3068 - val_acc: 0.8891\n",
      "Epoch 2/4\n",
      "62/62 [==============================] - 1194s 19s/step - loss: 0.2926 - acc: 0.8805 - val_loss: 0.2466 - val_acc: 0.9042\n",
      "Epoch 3/4\n",
      "62/62 [==============================] - 1179s 19s/step - loss: 0.2389 - acc: 0.9068 - val_loss: 0.3127 - val_acc: 0.8538\n",
      "Epoch 4/4\n",
      "62/62 [==============================] - 1178s 19s/step - loss: 0.2278 - acc: 0.9068 - val_loss: 0.2979 - val_acc: 0.8720\n"
     ]
    }
   ],
   "source": [
    "history = model.fit(\n",
    "    train_image_ds,\n",
    "    epochs=4,\n",
    "    validation_data=test_image_ds,\n",
    "    steps_per_epoch=train_count//BATCH_SIZE,\n",
    "    validation_steps=test_count//BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 解冻卷积层\n",
    "covn_base.trainable = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "19"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(covn_base.layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 调整最后三层\n",
    "fine_tune_at = -3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 除了最后三层将其他所有层设置为不可训练\n",
    "for layer in covn_base.layers[:fine_tune_at]:\n",
    "    layer.trainable = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 将学习速率调小，下探极值\n",
    "model.compile(loss='binary_crossentropy',\n",
    "             optimizer=tf.keras.optimizers.Adam(lr=0.0005/10),\n",
    "             metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5/14\n",
      "62/62 [==============================] - 1265s 20s/step - loss: 0.1942 - accuracy: 0.9254 - val_loss: 0.1455 - val_accuracy: 0.9375\n",
      "Epoch 6/14\n",
      "62/62 [==============================] - 1200s 19s/step - loss: 0.0928 - accuracy: 0.9693 - val_loss: 0.1574 - val_accuracy: 0.9365\n",
      "Epoch 7/14\n",
      "62/62 [==============================] - 1193s 19s/step - loss: 0.0519 - accuracy: 0.9849 - val_loss: 0.1220 - val_accuracy: 0.9546\n",
      "Epoch 8/14\n",
      "62/62 [==============================] - 1190s 19s/step - loss: 0.0265 - accuracy: 0.9960 - val_loss: 0.1323 - val_accuracy: 0.9516\n",
      "Epoch 9/14\n",
      "62/62 [==============================] - 1194s 19s/step - loss: 0.0173 - accuracy: 0.9975 - val_loss: 0.1292 - val_accuracy: 0.9476\n",
      "Epoch 10/14\n",
      "62/62 [==============================] - 1190s 19s/step - loss: 0.0124 - accuracy: 0.9975 - val_loss: 0.1168 - val_accuracy: 0.9556\n",
      "Epoch 11/14\n",
      "62/62 [==============================] - 1190s 19s/step - loss: 0.0074 - accuracy: 1.0000 - val_loss: 0.1379 - val_accuracy: 0.9496\n",
      "Epoch 12/14\n",
      "62/62 [==============================] - 1195s 19s/step - loss: 0.0027 - accuracy: 1.0000 - val_loss: 0.1439 - val_accuracy: 0.9536\n",
      "Epoch 13/14\n",
      "62/62 [==============================] - 1191s 19s/step - loss: 0.0021 - accuracy: 1.0000 - val_loss: 0.1416 - val_accuracy: 0.9526\n",
      "Epoch 14/14\n",
      "62/62 [==============================] - 1205s 19s/step - loss: 0.0015 - accuracy: 1.0000 - val_loss: 0.1426 - val_accuracy: 0.9516\n"
     ]
    }
   ],
   "source": [
    "# 初始训练epoch\n",
    "initial_epochs = 4\n",
    "# 微调后训练epoch\n",
    "fine_tune_epochs = 10\n",
    "total_epochs = initial_epochs+fine_tune_epochs\n",
    "\n",
    "history = model.fit(\n",
    "    train_image_ds,\n",
    "    epochs=total_epochs,\n",
    "    initial_epoch=initial_epochs,\n",
    "    validation_data=test_image_ds,\n",
    "    steps_per_epoch=train_count//BATCH_SIZE,\n",
    "    validation_steps=test_count//BATCH_SIZE\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
